{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 0
   },
   "source": [
    "# Batch Normalization\n",
    ":label:`sec_batch_norm`\n",
    "\n",
    "Training deep neural nets is difficult.\n",
    "And getting them to converge in a reasonable amount of time can be tricky.\n",
    "In this section, we describe batch normalization (BN)\n",
    ":cite:`Ioffe.Szegedy.2015`, a popular and effective technique\n",
    "that consistently accelerates the convergence of deep nets.\n",
    "Together with residual blocks—covered in :numref:`sec_resnet`—BN\n",
    "has made it possible for practitioners\n",
    "to routinely train networks with over 100 layers.\n",
    "\n",
    "\n",
    "\n",
    "## Training Deep Networks\n",
    "\n",
    "To motivate batch normalization, let us review\n",
    "a few practical challenges that arise\n",
    "when training ML models and neural nets in particular.\n",
    "\n",
    "1. Choices regarding data preprocessing often\n",
    "   make an enormous difference in the final results.\n",
    "   Recall our application of multilayer perceptrons\n",
    "   to predicting house prices (:numref:`sec_kaggle_house`).\n",
    "   Our first step when working with real data\n",
    "   was to standardize our input features\n",
    "   to each have a mean of *zero* and variance of *one*.\n",
    "   Intuitively, this standardization plays nicely with our optimizers\n",
    "   because it puts the parameters a-priori at a similar scale.\n",
    "2. For a typical MLP or CNN, as we train,\n",
    "   the activations in intermediate layers\n",
    "   may take values with widely varying magnitudes—both\n",
    "   along the layers from the input to the output,\n",
    "   across nodes in the same layer,\n",
    "   and over time due to our updates to the model's parameters.\n",
    "   The inventors of batch normalization postulated informally\n",
    "   that this drift in the distribution of activations\n",
    "   could hamper the convergence of the network.\n",
    "   Intuitively, we might conjecture that if one\n",
    "   layer has activation values that are 100x that of another layer,\n",
    "   this might necessitate compensatory adjustments in the learning rates.\n",
    "3. Deeper networks are complex and easily capable of overfitting.\n",
    "   This means that regularization becomes more critical.\n",
    "\n",
    "Batch normalization is applied to individual layers\n",
    "(optionally, to all of them) and works as follows:\n",
    "In each training iteration,\n",
    "we first normalize the inputs (of batch normalization)\n",
    "by subtracting their mean and\n",
    "dividing by their standard deviation,\n",
    "where both are estimated based on the statistics of the current minibatch.\n",
    "Next, we apply a scaling coefficient and a scaling offset.\n",
    "It is precisely due to this *normalization* based on *batch* statistics\n",
    "that *batch normalization* derives its name.\n",
    "\n",
    "Note that if we tried to apply BN with minibatches of size $1$,\n",
    "we would not be able to learn anything.\n",
    "That is because after subtracting the means,\n",
    "each hidden node would take value $0$!\n",
    "As you might guess, since we are devoting a whole section to BN,\n",
    "with large enough minibatches, the approach proves effective and stable.\n",
    "One takeaway here is that when applying BN,\n",
    "the choice of minibatch size may be\n",
    "even more significant than without BN.\n",
    "\n",
    "Formally, BN transforms the activations at a given layer $\\mathbf{x}$\n",
    "according to the following expression:\n",
    "\n",
    "$$\\mathrm{BN}(\\mathbf{x}) = \\mathbf{\\gamma} \\odot \\frac{\\mathbf{x} - \\hat{\\mathbf{\\mu}}}{\\hat\\sigma} + \\mathbf{\\beta}$$\n",
    "\n",
    "Here, $\\hat{\\mathbf{\\mu}}$ is the minibatch sample mean\n",
    "and $\\hat{\\mathbf{\\sigma}}$ is the minibatch sample standard deviation.\n",
    "After applying BN, the resulting minibatch of activations\n",
    "has zero mean and unit variance.\n",
    "Because the choice of unit variance\n",
    "(vs some other magic number) is an arbitrary choice,\n",
    "we commonly include coordinate-wise\n",
    "scaling coefficients $\\mathbf{\\gamma}$ and offsets $\\mathbf{\\beta}$.\n",
    "Consequently, the activation magnitudes\n",
    "for intermediate layers cannot diverge during training\n",
    "because BN actively centers and rescales them back\n",
    "to a given mean and size (via $\\mathbf{\\mu}$ and $\\sigma$).\n",
    "One piece of practitioner's intuition/wisdom\n",
    "is that BN seems to allows for more aggressive learning rates.\n",
    "\n",
    "\n",
    "Formally, denoting a particular minibatch by $\\mathcal{B}$,\n",
    "we calculate $\\hat{\\mathbf{\\mu}}_\\mathcal{B}$ and $\\hat\\sigma_\\mathcal{B}$ as follows:\n",
    "\n",
    "$$\\hat{\\mathbf{\\mu}}_\\mathcal{B} \\leftarrow \\frac{1}{|\\mathcal{B}|} \\sum_{\\mathbf{x} \\in \\mathcal{B}} \\mathbf{x}\n",
    "\\text{ and }\n",
    "\\hat{\\mathbf{\\sigma}}_\\mathcal{B}^2 \\leftarrow \\frac{1}{|\\mathcal{B}|} \\sum_{\\mathbf{x} \\in \\mathcal{B}} (\\mathbf{x} - \\mathbf{\\mu}_{\\mathcal{B}})^2 + \\epsilon$$\n",
    "\n",
    "Note that we add a small constant $\\epsilon > 0$\n",
    "to the variance estimate\n",
    "to ensure that we never attempt division by zero,\n",
    "even in cases where the empirical variance estimate might vanish.\n",
    "The estimates $\\hat{\\mathbf{\\mu}}_\\mathcal{B}$\n",
    "and $\\hat{\\mathbf{\\sigma}}_\\mathcal{B}$ counteract the scaling issue\n",
    "by using noisy estimates of mean and variance.\n",
    "You might think that this noisiness should be a problem.\n",
    "As it turns out, this is actually beneficial.\n",
    "\n",
    "This turns out to be a recurring theme in deep learning.\n",
    "For reasons that are not yet well-characterized theoretically,\n",
    "various sources of noise in optimization\n",
    "often lead to faster training and less overfitting.\n",
    "While traditional machine learning theorists\n",
    "might buckle at this characterization,\n",
    "this variation appears to act as a form of regularization.\n",
    "In some preliminary research,\n",
    ":cite:`Teye.Azizpour.Smith.2018` and :cite:`Luo.Wang.Shao.ea.2018`\n",
    "relate the properties of BN to Bayesian Priors and penalties respectively.\n",
    "In particular, this sheds some light on the puzzle\n",
    "of why BN works best for moderate minibatches sizes in the $50$–$100$ range.\n",
    "\n",
    "Fixing a trained model, you might (rightly) think\n",
    "that we would prefer to use the entire dataset\n",
    "to estimate the mean and variance.\n",
    "Once training is complete, why would we want\n",
    "the same image to be classified differently,\n",
    "depending on the batch in which it happens to reside?\n",
    "During training, such exact calculation is infeasible\n",
    "because the activations for all data points\n",
    "change every time we update our model.\n",
    "However, once the model is trained,\n",
    "we can calculate the means and variances\n",
    "of each layer's activations based on the entire dataset.\n",
    "Indeed this is standard practice for\n",
    "models employing batch normalization\n",
    "and thus BN layers function differently\n",
    "in *training mode* (normalizing by minibatch statistics)\n",
    "and in *prediction mode* (normalizing by dataset statistics).\n",
    "\n",
    "We are now ready to take a look at how batch normalization works in practice.\n",
    "\n",
    "\n",
    "## Batch Normalization Layers\n",
    "\n",
    "Batch normalization implementations for fully-connected layers\n",
    "and convolutional layers are slightly different.\n",
    "We discuss both cases below.\n",
    "Recall that one key differences between BN and other layers\n",
    "is that because BN operates on a full minibatch at a time,\n",
    "we cannot just ignore the batch dimension\n",
    "as we did before when introducing other layers.\n",
    "\n",
    "\n",
    "### Fully-Connected Layers\n",
    "\n",
    "When applying BN to fully-connected layers,\n",
    "we usually insert BN after the affine transformation\n",
    "and before the nonlinear activation function.\n",
    "Denoting the input to the layer by $\\mathbf{x}$,\n",
    "the linear transform (with weights $\\theta$) by $f_{\\theta}(\\cdot)$,\n",
    "the activation function by $\\phi(\\cdot)$,\n",
    "and the BN operation with parameters $\\mathbf{\\beta}$ and $\\mathbf{\\gamma}$\n",
    "by $\\mathrm{BN}_{\\mathbf{\\beta}, \\mathbf{\\gamma}}$,\n",
    "we can express the computation of a BN-enabled,\n",
    "fully-connected layer $\\mathbf{h}$ as follows:\n",
    "\n",
    "$$\\mathbf{h} = \\phi(\\mathrm{BN}_{\\mathbf{\\beta}, \\mathbf{\\gamma}}(f_{\\mathbf{\\theta}}(\\mathbf{x}) ) ) $$\n",
    "\n",
    "Recall that mean and variance are computed\n",
    "on the *same* minibatch $\\mathcal{B}$\n",
    "on which the transformation is applied.\n",
    "Also recall that the scaling coefficient $\\mathbf{\\gamma}$\n",
    "and the offset $\\mathbf{\\beta}$ are parameters that need to be learned\n",
    "jointly with the more familiar parameters $\\mathbf{\\theta}$.\n",
    "\n",
    "### Convolutional Layers\n",
    "\n",
    "Similarly, with convolutional layers,\n",
    "we typically apply BN after the convolution\n",
    "and before the nonlinear activation function.\n",
    "When the convolution has multiple output channels,\n",
    "we need to carry out batch normalization\n",
    "for *each* of the outputs of these channels,\n",
    "and each channel has its own scale and shift parameters,\n",
    "both of which are scalars.\n",
    "Assume that our minibatches contain $m$ each\n",
    "and that for each channel,\n",
    "the output of the convolution has height $p$ and width $q$.\n",
    "For convolutional layers, we carry out each batch normalization\n",
    "over the $m \\cdot p \\cdot q$ elements per output channel simultaneously.\n",
    "Thus we collect the values over all spatial locations\n",
    "when computing the mean and variance\n",
    "and consequently (within a given channel)\n",
    "apply the same $\\hat{\\mathbf{\\mu}}$ and $\\hat{\\mathbf{\\sigma}}$\n",
    "to normalize the values at each spatial location.\n",
    "\n",
    "\n",
    "### Batch Normalization During Prediction\n",
    "\n",
    "As we mentioned earlier, BN typically behaves differently\n",
    "in training mode and prediction mode.\n",
    "First, the noise in $\\mathbf{\\mu}$ and $\\mathbf{\\sigma}$\n",
    "arising from estimating each on minibatches\n",
    "are no longer desirable once we have trained the model.\n",
    "Second, we might not have the luxury\n",
    "of computing per-batch normalization statistics, e.g.,\n",
    "we might need to apply our model to make one prediction at a time.\n",
    "\n",
    "Typically, after training, we use the entire dataset\n",
    "to compute stable estimates of the activation statistics\n",
    "and then fix them at prediction time.\n",
    "Consequently, BN behaves differently during training and at test time.\n",
    "Recall that dropout also exhibits this characteristic.\n",
    "\n",
    "## Implementation from Scratch\n",
    "\n",
    "Below, firstly we get all the relevant libraries needed to implement BatchNorm. After that, we implement a batch normalization layer with NDArrays from scratch:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../utils/djl-imports\n",
    "%load ../utils/plot-utils\n",
    "%load ../utils/Training.java\n",
    "%load ../utils/Accumulator.java"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ai.djl.basicdataset.cv.classification.*;\n",
    "import org.apache.commons.lang3.ArrayUtils;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 1,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "public NDList batchNormUpdate(NDArray X, NDArray gamma,\n",
    "                   NDArray beta, NDArray movingMean, NDArray movingVar,\n",
    "                   float eps, float momentum, boolean isTraining) {\n",
    "    // attach moving mean and var to submanager to close intermediate computation values\n",
    "    // at the end to avoid memory leak\n",
    "    try(NDManager subManager = movingMean.getManager().newSubManager()){\n",
    "        movingMean.attach(subManager);\n",
    "        movingVar.attach(subManager);\n",
    "        NDArray xHat;\n",
    "        NDArray mean;\n",
    "        NDArray var;\n",
    "        if (!isTraining) {\n",
    "            // If it is the prediction mode, directly use the mean and variance\n",
    "            // obtained from the incoming moving average\n",
    "            xHat = X.sub(movingMean).div(movingVar.add(eps).sqrt());\n",
    "        } else {\n",
    "            if (X.getShape().dimension() == 2) {\n",
    "                // When using a fully connected layer, calculate the mean and\n",
    "                // variance on the feature dimension\n",
    "                mean = X.mean(new int[]{0}, true);\n",
    "                var = X.sub(mean).pow(2).mean(new int[]{0}, true);\n",
    "            } else {\n",
    "                // When using a two-dimensional convolutional layer, calculate the\n",
    "                // mean and variance on the channel dimension (axis=1). Here we\n",
    "                // need to maintain the shape of `X`, so that the broadcast\n",
    "                // operation can be carried out later\n",
    "                mean = X.mean(new int[]{0, 2, 3}, true);\n",
    "                var = X.sub(mean).pow(2).mean(new int[]{0, 2, 3}, true);\n",
    "            }\n",
    "            // In training mode, the current mean and variance are used for the\n",
    "            // standardization\n",
    "            xHat = X.sub(mean).div(var.add(eps).sqrt());\n",
    "            // Update the mean and variance of the moving average\n",
    "            movingMean = movingMean.mul(momentum).add(mean.mul(1.0f - momentum));\n",
    "            movingVar = movingVar.mul(momentum).add(var.mul(1.0f - momentum));\n",
    "        }\n",
    "        NDArray Y = xHat.mul(gamma).add(beta); // Scale and shift\n",
    "        // attach moving mean and var back to original manager to keep their values\n",
    "        movingMean.attach(subManager.getParentManager());\n",
    "        movingVar.attach(subManager.getParentManager());\n",
    "        return new NDList(Y, movingMean, movingVar);\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 4
   },
   "source": [
    "\n",
    "We can now create a proper `BatchNorm` layer. Our layer will maintain proper parameters corresponding for scale gamma and shift beta, both of which will be updated in the course of training. Additionally, our layer will maintain a moving average of the means and variances for subsequent use during model prediction. The numFeatures parameter required by the BatchNorm instance is the number of outputs for a fully-connected layer and the number of output channels for a convolutional layer. The numDimensions parameter also required by this instance is 2 for a fully-connected layer and 4 for a convolutional layer.\n",
    "\n",
    "Putting aside the algorithmic details, note the design pattern underlying our implementation of the layer. Typically, we define the math in a separate function, say `batchNormUpdate`. We then integrate this functionality into a custom layer, whose code mostly addresses bookkeeping matters, such as moving data to the right device context, allocating and initializing any required variables, keeping track of running averages (here for mean and variance), etc. This pattern enables a clean separation of math from boilerplate code. Also note that for the sake of convenience we did not worry about automatically inferring the input shape here, thus we need to specify the number of features throughout. Do not worry, the DJL `BatchNorm` layer will care of this for us.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 5,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "public class BatchNormBlock extends AbstractBlock {\n",
    "\n",
    "    private NDArray movingMean;\n",
    "    private NDArray movingVar;\n",
    "    private Parameter gamma;\n",
    "    private Parameter beta;\n",
    "    private Shape shape;\n",
    "\n",
    "    // num_features: the number of outputs for a fully-connected layer\n",
    "    // or the number of output channels for a convolutional layer.\n",
    "    // num_dims: 2 for a fully-connected layer and 4 for a convolutional layer.\n",
    "    public BatchNormBlock(int numFeatures, int numDimensions) {\n",
    "        if (numDimensions == 2) {\n",
    "            shape = new Shape(1, numFeatures);\n",
    "        } else {\n",
    "            shape = new Shape(1, numFeatures, 1, 1);\n",
    "        }\n",
    "        // The scale parameter and the shift parameter involved in gradient\n",
    "        // finding and iteration are initialized to 0 and 1 respectively\n",
    "        gamma = addParameter(\n",
    "                    Parameter.builder()\n",
    "                        .setName(\"gamma\")\n",
    "                        .setType(Parameter.Type.GAMMA)\n",
    "                        .optShape(shape)\n",
    "                        .build());\n",
    "        \n",
    "        beta = addParameter(\n",
    "                    Parameter.builder()\n",
    "                        .setName(\"beta\")\n",
    "                        .setType(Parameter.Type.BETA)\n",
    "                        .optShape(shape)\n",
    "                        .build());\n",
    "\n",
    "        // All the variables not involved in gradient finding and iteration are\n",
    "        // initialized to 0. Create a base manager to maintain their values \n",
    "        // throughout the entire training process\n",
    "        NDManager manager = NDManager.newBaseManager();\n",
    "        movingMean = manager.zeros(shape);\n",
    "        movingVar = manager.zeros(shape);\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    public String toString() {\n",
    "        return \"BatchNormBlock()\";\n",
    "    }\n",
    "    \n",
    "    @Override\n",
    "    protected NDList forwardInternal(\n",
    "            ParameterStore parameterStore,\n",
    "            NDList inputs,\n",
    "            boolean training,\n",
    "            PairList<String, Object> params) {\n",
    "        NDList result = batchNormUpdate(inputs.singletonOrThrow(),\n",
    "                gamma.getArray(), beta.getArray(), this.movingMean, this.movingVar, 1e-12f, 0.9f, training);\n",
    "        // close previous NDArray before assigning new values\n",
    "        if(training){\n",
    "            this.movingMean.close();\n",
    "            this.movingVar.close();\n",
    "        }\n",
    "        // Save the updated `movingMean` and `movingVar`\n",
    "        this.movingMean = result.get(1);\n",
    "        this.movingVar = result.get(2);\n",
    "        return new NDList(result.get(0));\n",
    "    }\n",
    "    \n",
    "    @Override\n",
    "    public Shape[] getOutputShapes(Shape[] inputs) {\n",
    "        Shape[] current = inputs;\n",
    "        for (Block block : children.values()) {\n",
    "            current = block.getOutputShapes(current);\n",
    "        }\n",
    "        return current;\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 8
   },
   "source": [
    "## Using a Batch Normalization LeNet\n",
    "\n",
    "To see how to apply `BatchNorm` in context,\n",
    "below we apply it to a traditional LeNet model (:numref:`sec_lenet`).\n",
    "Recall that BN is typically applied\n",
    "after the convolutional layers and fully-connected layers\n",
    "but before the corresponding activation functions.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 9,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "SequentialBlock net = new SequentialBlock()\n",
    "    .add(\n",
    "        Conv2d.builder()\n",
    "            .setKernelShape(new Shape(5, 5))\n",
    "            .setFilters(6).build())\n",
    "    .add(new BatchNormBlock(6, 4))\n",
    "    .add(Pool.maxPool2dBlock(new Shape(2, 2), new Shape(2, 2)))\n",
    "    .add(\n",
    "        Conv2d.builder()\n",
    "            .setKernelShape(new Shape(5, 5))\n",
    "            .setFilters(16).build())\n",
    "    .add(new BatchNormBlock(16, 4))\n",
    "    .add(Activation::sigmoid)\n",
    "    .add(Pool.maxPool2dBlock(new Shape(2, 2), new Shape(2, 2)))\n",
    "    .add(Blocks.batchFlattenBlock())\n",
    "    .add(Linear.builder().setUnits(120).build())\n",
    "    .add(new BatchNormBlock(120, 2))\n",
    "    .add(Activation::sigmoid)\n",
    "    .add(Blocks.batchFlattenBlock())\n",
    "    .add(Linear.builder().setUnits(84).build())\n",
    "    .add(new BatchNormBlock(84, 2))\n",
    "    .add(Activation::sigmoid)\n",
    "    .add(Linear.builder().setUnits(10).build());"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Let's initialize the batchSize, numEpochs and the relevant arrays to store the data from the training function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "int batchSize = 256;\n",
    "int numEpochs = Integer.getInteger(\"MAX_EPOCH\", 10);\n",
    "        \n",
    "double[] trainLoss;\n",
    "double[] testAccuracy;\n",
    "double[] epochCount;\n",
    "double[] trainAccuracy;\n",
    "\n",
    "epochCount = new double[numEpochs];\n",
    "\n",
    "for (int i = 0; i < epochCount.length; i++) {\n",
    "    epochCount[i] = i+1;\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 12,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "As before, we will train our network on the Fashion-MNIST dataset.\n",
    "This code is virtually identical to that when we first trained LeNet (:numref:`sec_lenet`).\n",
    "The main difference is the considerably larger learning rate.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "FashionMnist trainIter = FashionMnist.builder()\n",
    "        .optUsage(Dataset.Usage.TRAIN)\n",
    "        .setSampling(batchSize, true)\n",
    "        .optLimit(Long.getLong(\"DATASET_LIMIT\", Long.MAX_VALUE))\n",
    "        .build();\n",
    "\n",
    "FashionMnist testIter = FashionMnist.builder()\n",
    "        .optUsage(Dataset.Usage.TEST)\n",
    "        .setSampling(batchSize, true)\n",
    "        .optLimit(Long.getLong(\"DATASET_LIMIT\", Long.MAX_VALUE))\n",
    "        .build();\n",
    "\n",
    "trainIter.prepare();\n",
    "testIter.prepare();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "float lr = 1.0f;\n",
    "\n",
    "Loss loss = Loss.softmaxCrossEntropyLoss();\n",
    "\n",
    "Tracker lrt = Tracker.fixed(lr);\n",
    "Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();\n",
    "\n",
    "DefaultTrainingConfig config = new DefaultTrainingConfig(loss)\n",
    "                .optOptimizer(sgd) // Optimizer (loss function)\n",
    "                .optDevices(Engine.getInstance().getDevices(1)) // single GPU\n",
    "                .addEvaluator(new Accuracy()) // Model Accuracy\n",
    "                .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging\n",
    "\n",
    "Model model = Model.newInstance(\"batch-norm\");\n",
    "model.setBlock(net);\n",
    "Trainer trainer = model.newTrainer(config);\n",
    "trainer.initialize(new Shape(1, 1, 28, 28));\n",
    "\n",
    "Map<String, double[]> evaluatorMetrics = new HashMap<>();\n",
    "double avgTrainTimePerEpoch = Training.trainingChapter6(trainIter, testIter, numEpochs, trainer, evaluatorMetrics);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "trainLoss = evaluatorMetrics.get(\"train_epoch_SoftmaxCrossEntropyLoss\");\n",
    "trainAccuracy = evaluatorMetrics.get(\"train_epoch_Accuracy\");\n",
    "testAccuracy = evaluatorMetrics.get(\"validate_epoch_Accuracy\");\n",
    "\n",
    "System.out.printf(\"loss %.3f,\", trainLoss[numEpochs - 1]);\n",
    "System.out.printf(\" train acc %.3f,\", trainAccuracy[numEpochs - 1]);\n",
    "System.out.printf(\" test acc %.3f\\n\", testAccuracy[numEpochs - 1]);\n",
    "System.out.printf(\"%.1f examples/sec\", trainIter.size() / (avgTrainTimePerEpoch / Math.pow(10, 9)));\n",
    "System.out.println();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 15,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Let us have a look at the scale parameter `gamma`\n",
    "and the shift parameter `beta` learned\n",
    "from the first batch normalization layer.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 16,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "// Printing the value of gamma and beta in the first BatchNorm layer.\n",
    "List<Parameter> batchNormFirstParams = net.getChildren().values().get(1).getParameters().values();\n",
    "System.out.println(\"gamma \" + batchNormFirstParams.get(0).getArray().reshape(-1));\n",
    "System.out.println(\"beta \" + batchNormFirstParams.get(1).getArray().reshape(-1));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![Contour Gradient Descent.](https://d2l-java-resources.s3.amazonaws.com/img/chapter_convolution-modern-cnn-batchnorm1.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "String[] lossLabel = new String[trainLoss.length + testAccuracy.length + trainAccuracy.length];\n",
    "\n",
    "Arrays.fill(lossLabel, 0, trainLoss.length, \"train loss\");\n",
    "Arrays.fill(lossLabel, trainAccuracy.length, trainLoss.length + trainAccuracy.length, \"train acc\");\n",
    "Arrays.fill(lossLabel, trainLoss.length + trainAccuracy.length,\n",
    "                trainLoss.length + testAccuracy.length + trainAccuracy.length, \"test acc\");\n",
    "\n",
    "Table data = Table.create(\"Data\").addColumns(\n",
    "    DoubleColumn.create(\"epoch\", ArrayUtils.addAll(epochCount, ArrayUtils.addAll(epochCount, epochCount))),\n",
    "    DoubleColumn.create(\"metrics\", ArrayUtils.addAll(trainLoss, ArrayUtils.addAll(trainAccuracy, testAccuracy))),\n",
    "    StringColumn.create(\"lossLabel\", lossLabel)\n",
    ");\n",
    "\n",
    "render(LinePlot.create(\"\", data, \"epoch\", \"metrics\", \"lossLabel\"),\"text/html\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 19
   },
   "source": [
    "## Concise Implementation\n",
    "\n",
    "Compared with the `BatchNorm` class, which we just defined ourselves, the `BatchNorm` class defined by `nn` module in DJL is easier to use. In DJL, we do not have to worry about `numFeatures` or `numDimensions`. Instead, these parameter values will be inferred automatically via delayed initialization. Otherwise, the code looks virtually identical to the application our implementation above.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 20,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "SequentialBlock block = new SequentialBlock()\n",
    "    .add(\n",
    "            Conv2d.builder()\n",
    "                    .setKernelShape(new Shape(5, 5))\n",
    "                    .setFilters(6).build())\n",
    "    .add(BatchNorm.builder().build())\n",
    "    .add(Pool.maxPool2dBlock(new Shape(2, 2), new Shape(2, 2)))\n",
    "    .add(\n",
    "            Conv2d.builder()\n",
    "                    .setKernelShape(new Shape(5, 5))\n",
    "                    .setFilters(16).build())\n",
    "    .add(BatchNorm.builder().build())\n",
    "    .add(Activation::sigmoid)\n",
    "    .add(Pool.maxPool2dBlock(new Shape(2, 2), new Shape(2, 2)))\n",
    "    .add(Blocks.batchFlattenBlock())\n",
    "    .add(Linear.builder().setUnits(120).build())\n",
    "    .add(BatchNorm.builder().build())\n",
    "    .add(Activation::sigmoid)\n",
    "    .add(Blocks.batchFlattenBlock())\n",
    "    .add(Linear.builder().setUnits(84).build())\n",
    "    .add(BatchNorm.builder().build())\n",
    "    .add(Activation::sigmoid)\n",
    "    .add(Linear.builder().setUnits(10).build());"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 23,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Below, we use the same hyperparameters to train out model.\n",
    "Note that as usual, the high-level API variant runs much faster\n",
    "because its code has been compiled to C++/CUDA\n",
    "while our custom implementation must be interpreted by Python.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 24,
    "pycharm": {
     "name": "#%%\n"
    },
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "Loss loss = Loss.softmaxCrossEntropyLoss();\n",
    "\n",
    "Tracker lrt = Tracker.fixed(1.0f);\n",
    "Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();\n",
    "\n",
    "Model model = Model.newInstance(\"batch-norm\");\n",
    "model.setBlock(block);\n",
    "\n",
    "DefaultTrainingConfig config = new DefaultTrainingConfig(loss)\n",
    "                .optOptimizer(sgd) // Optimizer (loss function)\n",
    "                .addEvaluator(new Accuracy()) // Model Accuracy\n",
    "                .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging\n",
    "\n",
    "Trainer trainer = model.newTrainer(config);\n",
    "trainer.initialize(new Shape(1, 1, 28, 28));\n",
    "\n",
    "Map<String, double[]> evaluatorMetrics = new HashMap<>();\n",
    "double avgTrainTimePerEpoch = 0;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "avgTrainTimePerEpoch = Training.trainingChapter6(trainIter, testIter, numEpochs, trainer, evaluatorMetrics);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "trainLoss = evaluatorMetrics.get(\"train_epoch_SoftmaxCrossEntropyLoss\");\n",
    "trainAccuracy = evaluatorMetrics.get(\"train_epoch_Accuracy\");\n",
    "testAccuracy = evaluatorMetrics.get(\"validate_epoch_Accuracy\");\n",
    "\n",
    "System.out.printf(\"loss %.3f,\", trainLoss[numEpochs - 1]);\n",
    "System.out.printf(\" train acc %.3f,\", trainAccuracy[numEpochs - 1]);\n",
    "System.out.printf(\" test acc %.3f\\n\", testAccuracy[numEpochs - 1]);\n",
    "System.out.printf(\"%.1f examples/sec\", trainIter.size() / (avgTrainTimePerEpoch / Math.pow(10, 9)));\n",
    "System.out.println();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![Contour Gradient Descent.](https://d2l-java-resources.s3.amazonaws.com/img/chapter_convolution-modern-cnn-batchnorm2.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "String[] lossLabel = new String[trainLoss.length + testAccuracy.length + trainAccuracy.length];\n",
    "\n",
    "Arrays.fill(lossLabel, 0, trainLoss.length, \"train loss\");\n",
    "Arrays.fill(lossLabel, trainAccuracy.length, trainLoss.length + trainAccuracy.length, \"train acc\");\n",
    "Arrays.fill(lossLabel, trainLoss.length + trainAccuracy.length,\n",
    "                trainLoss.length + testAccuracy.length + trainAccuracy.length, \"test acc\");\n",
    "\n",
    "Table data = Table.create(\"Data\").addColumns(\n",
    "            DoubleColumn.create(\"epoch\", ArrayUtils.addAll(epochCount, ArrayUtils.addAll(epochCount, epochCount))),\n",
    "            DoubleColumn.create(\"metrics\", ArrayUtils.addAll(trainLoss, ArrayUtils.addAll(trainAccuracy, testAccuracy))),\n",
    "            StringColumn.create(\"lossLabel\", lossLabel)\n",
    ");\n",
    "\n",
    "render(LinePlot.create(\"\", data, \"epoch\", \"metrics\", \"lossLabel\"),\"text/html\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Controversy\n",
    "\n",
    "Intuitively, batch normalization is thought\n",
    "to make the optimization landscape smoother.\n",
    "However, we must be careful to distinguish between\n",
    "speculative intuitions and true explanations\n",
    "for the phenomena that we observe when training deep models.\n",
    "Recall that we do not even know why simpler\n",
    "deep neural networks (MLPs and conventional CNNs)\n",
    "generalize well in the first place.\n",
    "Even with dropout and $L_2$ regularization,\n",
    "they remain so flexible that their ability to generalize to unseen data\n",
    "cannot be explained via conventional learning-theoretic generalization guarantees.\n",
    "\n",
    "In the original paper proposing batch normalization,\n",
    "the authors, in addition to introducing a powerful and useful tool,\n",
    "offered an explanation for why it works:\n",
    "by reducing *internal covariate shift*.\n",
    "Presumably by *internal covariate shift* the authors\n",
    "meant something like the intuition expressed above—the\n",
    "notion that the distribution of activations changes\n",
    "over the course of training.\n",
    "However there were two problems with this explanation:\n",
    "(1) This drift is very different from *covariate shift*,\n",
    "rendering the name a misnomer.\n",
    "(2) The explanation offers an under-specified intuition\n",
    "but leaves the question of *why precisely this technique works*\n",
    "an open question wanting for a rigorous explanation.\n",
    "Throughout this book, we aim to convey the intuitions that practitioners\n",
    "use to guide their development of deep neural networks.\n",
    "However, we believe that it is important\n",
    "to separate these guiding intuitions\n",
    "from established scientific fact.\n",
    "Eventually, when you master this material\n",
    "and start writing your own research papers\n",
    "you will want to be clear to delineate\n",
    "between technical claims and hunches.\n",
    "\n",
    "Following the success of batch normalization,\n",
    "its explanation in terms of *internal covariate shift*\n",
    "has repeatedly surfaced in debates in the technical literature\n",
    "and broader discourse about how to present machine learning research.\n",
    "In a memorable speech given while accepting a Test of Time Award\n",
    "at the 2017 NeurIPS conference,\n",
    "Ali Rahimi used *internal covariate shift*\n",
    "as a focal point in an argument likening\n",
    "the modern practice of deep learning to alchemy.\n",
    "Subsequently, the example was revisited in detail\n",
    "in a position paper outlining\n",
    "troubling trends in machine learning :cite:`Lipton.Steinhardt.2018`.\n",
    "In the technical literature other authors (:cite:`Santurkar.Tsipras.Ilyas.ea.2018`)\n",
    "have proposed alternative explanations for the success of BN,\n",
    "some claiming that BN's success comes despite exhibiting behavior\n",
    "that is in some ways opposite to those claimed in the original paper.\n",
    "\n",
    "We note that the *internal covariate shift*\n",
    "is no more worthy of criticism than any of\n",
    "thousands of similarly vague claims\n",
    "made every year in the technical ML literature.\n",
    "Likely, its resonance as a focal point of these debates\n",
    "owes to its broad recognizability to the target audience.\n",
    "Batch normalization has proven an indispensable method,\n",
    "applied in nearly all deployed image classifiers,\n",
    "earning the paper that introduced the technique\n",
    "tens of thousands of citations.\n",
    "\n",
    "\n",
    "## Summary\n",
    "\n",
    "* During model training, batch normalization continuously adjusts the intermediate output of the neural network by utilizing the mean and standard deviation of the minibatch, so that the values of the intermediate output in each layer throughout the neural network are more stable.\n",
    "* The batch normalization methods for fully connected layers and convolutional layers are slightly different.\n",
    "* Like a dropout layer, batch normalization layers have different computation results in training mode and prediction mode.\n",
    "* Batch Normalization has many beneficial side effects, primarily that of regularization. On the other hand, the original motivation of reducing covariate shift seems not to be a valid explanation.\n",
    "\n",
    "## Exercises\n",
    "\n",
    "1. Can we remove the fully connected affine transformation before the batch normalization or the bias parameter in convolution computation?\n",
    "    * Find an equivalent transformation that applies prior to the fully connected layer.\n",
    "    * Is this reformulation effective. Why (not)?\n",
    "1. Compare the learning rates for LeNet with and without batch normalization.\n",
    "    * Plot the decrease in training and test error.\n",
    "    * What about the region of convergence? How large can you make the learning rate?\n",
    "1. Do we need Batch Normalization in every layer? Experiment with it?\n",
    "1. Can you replace Dropout by Batch Normalization? How does the behavior change?\n",
    "1. Fix the coefficients `beta` and `gamma` , and observe and analyze the results.\n",
    "1. Review the online documentation for `BatchNorm` to see the other applications for Batch Normalization.\n",
    "1. Research ideas: think of other normalization transforms that you can apply? Can you apply the probability integral transform? How about a full rank covariance estimate?\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
